
[AWQ] speed improvements #2188

Closed

HDCharles wants to merge 8 commits into main from 98_awq_followups

Conversation

@HDCharles
Collaborator

@HDCharles HDCharles commented Jan 5, 2026

SUMMARY:
We identified several fixes, TODOs, and improvements after the AWQ generalization PR that increase AWQ speed. This PR largely implements them; details below.

speedup on:
python /home/HDCharles/repos/llm-compressor/examples/awq/llama_example.py
OLD:
(8.00 minutes)
GPU Memory - Peak: 10.00 GB
NOW:
(7.09 minutes)
GPU Memory - Peak: 13.18 GB

RESULT:
11.37% speedup; the memory increase is expected and is primarily due to changes #1 and #4 below

changes:

  1. instead of recording the fp16_baseline_output during apply_smoothing, we add a hook so that the output is captured during sequential pipeline execution; we also keep it on device to avoid unnecessary device on/offloading
  2. we concatenate outputs into a single tensor for faster error calculation
  3. instead of recording the entire state dict during compute_best_scale, we only record the state of the balance layers (also keep them onloaded on gpu instead of offloading since we're storing significantly less now)
  4. previously we would write the stored value to the balance layer, then update that value based on the scale factor (2 writes); now we calculate the scaled balance layer and update it directly (1 write)
  5. we don't need to update the offloaded parameter when calculating the best scale, only the local one
  6. improvement #4 also allows us to be more targeted; previously, during the first write, we would update the whole state dict, which is no longer necessary
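A rough sketch of change #1 follows. The cache class, hook wiring, and names here are illustrative stand-ins, not the exact identifiers from this PR:

```python
# Sketch of change #1: capture fp16 baseline outputs via a forward hook
# during the sequential pipeline run, instead of recomputing them later
# in apply_smoothing. Cached tensors stay on the execution device.
# Names and structure are illustrative, not the PR's actual code.
from collections import defaultdict

import torch


class BaselineOutputCache:
    def __init__(self):
        # module -> list of baseline outputs, kept on device
        self.outputs = defaultdict(list)
        self._handles = []

    def register(self, module: torch.nn.Module) -> None:
        def hook(mod, args, output):
            out = output[0] if isinstance(output, tuple) else output
            # detach and keep on device; no cpu offload round trip
            self.outputs[mod].append(out.detach())

        self._handles.append(module.register_forward_hook(hook))

    def remove(self) -> None:
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


# usage: register before the calibration forward passes, read afterwards
cache = BaselineOutputCache()
layer = torch.nn.Linear(4, 4)
cache.register(layer)
with torch.no_grad():
    layer(torch.randn(2, 4))
cache.remove()
```

Capturing during the existing pipeline pass means the baseline forward is only run once, which is where the bulk of the saved time comes from.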

other changes which were tested but not adopted:

  - torch compiling the best_scales_loop (device offloading prevented compilation)
  - calculating mse_loss progressively as each sample is run (slower)
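Change #2 (one MSE over concatenated outputs instead of a per-batch loop) can be sketched as a toy comparison; with equal-sized batches the mean of per-batch MSEs equals the MSE over the concatenation:

```python
# Sketch of change #2: a single mse_loss call over concatenated outputs
# instead of a Python loop over per-batch tensors. Toy data, not PR code.
import torch
import torch.nn.functional as F

fp16_outputs = [torch.randn(8, 16) for _ in range(4)]
int_w_outputs = [o + 0.01 * torch.randn_like(o) for o in fp16_outputs]

# per-batch loop: many small kernel launches
loss_loop = sum(
    F.mse_loss(a, b) for a, b in zip(fp16_outputs, int_w_outputs)
) / len(fp16_outputs)

# single concatenated call: one kernel; equivalent here because every
# batch has the same size (note torch.cat briefly holds both copies)
loss_cat = F.mse_loss(torch.cat(fp16_outputs), torch.cat(int_w_outputs))

print(torch.allclose(loss_loop, loss_cat, atol=1e-6))  # True
```

The trade-off noted in the follow-up PR is that `torch.cat` transiently doubles the memory held by the outputs, which matters for large calibration sets.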

TEST PLAN:
ran AWQ tests and examples to verify correctness

@gemini-code-assist
Contributor

Summary of Changes

Hello @HDCharles, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces several key optimizations to the AWQ quantization process, primarily focusing on improving execution speed. By refining how FP16 baseline outputs are cached, streamlining error calculation, and making state management more targeted during scale computation, the changes aim to reduce overall processing time. The reported impact is a reduction in quantization time for a Llama example from 7.96 minutes to 6.93 minutes.

Highlights

  • Optimized FP16 Baseline Output Caching: Instead of recomputing FP16 baseline outputs during apply_smoothing, a hook is now used to capture these outputs during the initial sequential pipeline execution, storing them in a dedicated cache for later retrieval.
  • Faster Error Calculation: Outputs are now concatenated into a single tensor for Mean Squared Error (MSE) calculation, significantly speeding up the error computation process by performing it in a single operation rather than batch by batch.
  • Targeted State Management: During compute_best_scale, only the state of the relevant balance layers is recorded and restored, rather than the entire parent module's state dictionary, reducing overhead and improving efficiency.
  • Streamlined Weight Scaling: The process of applying scale factors to weights has been optimized from two write operations (storing then updating) to a single direct calculation and update, improving efficiency and precision.
  • Reduced Offload Parameter Updates: The need to update offload parameters during the best scale calculation has been eliminated, as weight updates are now handled locally and more precisely within the balance layers.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several well-reasoned performance improvements to the AWQ modifier, focusing on optimizing data handling and reducing redundant computations. The changes, such as caching FP16 baseline outputs and concatenating tensors for faster loss calculation, are effective and result in a significant speedup. My review has identified one critical issue where model weights are not correctly restored after the grid search for scaling factors, which could lead to an incorrect model state. I have provided a code suggestion to address this. Additionally, I've included a medium-severity suggestion to improve code clarity. Overall, this is a great step forward in optimizing the AWQ implementation.

kylesayrs
kylesayrs previously approved these changes Jan 5, 2026
Collaborator

@kylesayrs kylesayrs left a comment


Looks great, awesome improvements

```diff
- update_offload_parameter(
-     balance_layer,
-     "weight",
+ balance_layer.weight.data = (
```
Collaborator


Nice job avoiding writing to the offload

Collaborator


So we don't need update_offload_parameter here because it all happens on the exec device, and the smooth function is done elsewhere after best_scales are calculated?

Collaborator Author


Yeah, we can just keep it in memory and not mess with that.

@github-actions

github-actions bot commented Jan 5, 2026

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@HDCharles HDCharles added enhancement New feature or request ready When a PR is ready for review awq For any issue / PR related to AWQ support labels Jan 6, 2026
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
old: (7.96 minutes)
now: (6.93 minutes)

meta llama 3-8b example

Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Collaborator

@brian-dellabetta brian-dellabetta left a comment


Looks good! One clarifying question and a question on the increased memory requirements

```diff
- update_offload_parameter(
-     balance_layer,
-     "weight",
+ balance_layer.weight.data = (
```
Collaborator


So we don't need update_offload_parameter here because it all happens on the exec device, and the smooth function is done elsewhere after best_scales are calculated?

```python
values = inspect.signature(module.forward).bind(*args, **kwargs)
self._parent_args_cache[module].append(values.arguments)

def cache_fp16_baseline_hook(
```
Collaborator


if we are now caching output activations for every mapping in a given subgraph, wouldn't this increase memory requirements quite a bit, especially for MoE models? For which model are you seeing the 30% memory increase that you mention in the summary?

Collaborator Author


Hold on, I'm rewriting this. I didn't realize that by default we don't enable offloading, so all my measurements were off. We do need to cache this, but not on GPU, and we can offload it.

@dsikka
Collaborator

dsikka commented Jan 14, 2026

@Mergifyio refresh

@mergify
Contributor

mergify bot commented Jan 14, 2026

refresh

✅ Pull request refreshed

@mergify
Contributor

mergify bot commented Jan 14, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @HDCharles.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 14, 2026
@dsikka dsikka mentioned this pull request Jan 19, 2026
25 tasks
@HDCharles
Collaborator Author

See #2265 for further discussion; it contains a better version of these changes.

@HDCharles HDCharles closed this Jan 20, 2026
dsikka pushed a commit that referenced this pull request Jan 23, 2026
We identified a number of inefficiencies and fixes after the AWQ
generalization [PR](#1961); this PR largely implements them, see details
below. Note: I previously made this
[speed improvements PR](#2188), which had
some issues that have been fixed in this one; that PR is going to be
closed.

# BENCHMARKS

To iterate more quickly, I ran these tests on models with most of their
[layers
removed](https://github.com/vllm-project/llm-compressor/blob/1f036248f0310b8e95d488fc5d20831bcc9b62b7/examples/awq/llama_example.py#L69);
the actual improvement should be a bit better, since the layers which
were removed are where the improvement happens. To replicate these
numbers, see the [first
commit](1f03624#diff-208ced55cba2d38b1bc9b03b5f79ae3483c0849cc708a6c1131c231aae5d4b3dR69).

| Runtime (min) | Improvement | PR    | Base  |
|---------------|-------------|-------|-------|
| llama_no_off  | 8.7%        | 6.17  | 6.76  |
| llama_off     | 5.6%        | 10.46 | 11.08 |
| moe_no_off*   | 2.8%        | 6.67  | 6.86  |
| moe_off*      | 1.9%        | 7.57  | 7.72  |

| Memory (GB)     | Improvement | PR    | Base  |
|-----------------|-------------|-------|-------|
| llama_PR_no_off | 7.8%        | 9.22  | 10.00 |
| llama_PR_off    | 17.6%       | 3.66  | 4.44  |
| moe_PR_no_off   | -5.2%       | 11.61 | 11.04 |
| moe_PR_off**    | -24.3%      | 2.92  | 2.35  |

\*The actual speedup for MoE is going to be higher than this. These
numbers are for a single layer being quantized so the calibration
overhead is going to attenuate the gains.

\*\* This worsening of memory is due to the weights being cached on
device, the layernorm -> up + gate proj mapping has to cache the up +
gate linears for the entire MLP layer which is fairly large.


# SUMMARY:
changes:

- targeted weight cache, no offloading
  - previously, in compute_best_scale, we would record the entire state dict of the parent and store it on CPU
  - now we only record the balance layer weights and store those on device, since they are generally small
- reduce load/write/rewrite of weights
  - during grid search we have to repeatedly update the weight to use scaled and fake-quantized versions of the weight; previously this was done by writing the original value, calculating the scaled value, and then writing that (2 writes)
  - we instead calculate the scaled value directly from the on-device cached value and write it once (1 write)
- fake quantize only the on-device weight
  - previously the on/offloaded balance layer weight was updated
  - we now update just the on-device value
- compute loss
  - we slightly optimize compute loss to reduce device movement
  - note that a number of approaches to improving the loss computation were attempted, including:
    - progressively calculating the loss while running the samples to get int_w_outputs, to avoid storing the whole int_w_output on device (was slower, and seems to save the same memory as deleting int_w_outputs after it's used)
    - using torch.cat to combine the fp16 outputs and int_w_outputs into a single tensor so we can do only a single mse calculation (torch.cat briefly doubles memory usage, so it created a significant memory increase)
    - avoiding torch.cat by preallocating a flat tensor of the needed size and progressively storing chunks into it (slow)
    - torch compile on compute loss and/or run samples (would likely speed up the run-samples code, but the offloading framework doesn't work well with it)
- del int_w_outputs: simply deleting this intermediate value after it's used saves a significant amount of memory
- change default offload device behavior: normally None; now, by default, we check whether we use the default MoE Mapping and, if so, offload to cpu by default
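The first three bullets above can be sketched together. Everything here is illustrative: `fake_quantize`, the placeholder scales, and the random loss stand in for the real AWQ machinery, not the PR's actual code:

```python
# Sketch: cache only the balance-layer weights on device, then during the
# grid search compute the scaled + fake-quantized weight directly from the
# cached copy in a single write, and restore once at the end.
import torch


def fake_quantize(w: torch.Tensor, n_bits: int = 4) -> torch.Tensor:
    # simple symmetric per-tensor fake quant; a stand-in for the real one
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().max() / qmax
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale


balance_layers = [torch.nn.Linear(16, 16), torch.nn.Linear(16, 16)]

# targeted cache: only the balance-layer weights, kept on device
cached = [layer.weight.data.clone() for layer in balance_layers]

best_loss, best_ratio = float("inf"), None
for ratio in torch.linspace(0.1, 1.0, 10):
    scales = torch.rand(16).pow(ratio)  # placeholder for AWQ scales
    for layer, orig in zip(balance_layers, cached):
        # single write: scaled + fake-quantized weight computed from the
        # cached copy; no restore-then-rescale double write, and no
        # update of the offloaded parameter
        layer.weight.data = fake_quantize(orig * scales)
    loss = torch.rand(())  # placeholder for MSE against fp16 baseline
    if loss < best_loss:
        best_loss, best_ratio = loss.item(), ratio.item()

# restore from the on-device cache once the search is done
for layer, orig in zip(balance_layers, cached):
    layer.weight.data = orig
```

Because only the small balance-layer weights are cached, keeping them on device is cheap, which is what removes the per-candidate on/offload traffic from the grid search.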

TEST PLAN:
(see first commit tests)

---------

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Etelis pushed a commit to Etelis/llm-compressor that referenced this pull request Jan 24, 2026
Etelis pushed a commit to Etelis/llm-compressor that referenced this pull request Jan 25, 2026
cajeonrh pushed a commit to cajeonrh/llm-compressor that referenced this pull request Feb 10, 2026